// Very Fast Neural Network
//
// Programming language: Processing (www.processing.org), which is Java-like.
// Uses the fast Walsh Hadamard transform as a fixed weight matrix
// and adjustable (parametric) activation functions as the trainable
// part.
// Parameters are initialized with random Gaussian values and trained
// with an improved backpropagation method.
// https://sites.google.com/view/algorithmshortcuts/home
// https://archive.org/details/whtebook-archive

// Test with Lissajous curves
// Sketch-wide state shared by setup() and draw().
color c1;       // colour passed to set() for every plotted point; assigned in setup()
int epochs;     // number of passes over all examples trained per frame; assigned in setup()
float[][] ex;   // training examples: one row per curve, samples interleaved as x,y,x,y,...
float[] work = new float[256]; // receives the network's recalled output in draw()
SwNet net;      // the network being trained; created in setup()

// One-time initialisation: window, text size, plot colour, the network,
// and the eight Lissajous curves used as training examples.
void setup() {
  size(400, 400);
  textSize(16);
  c1 = color(255, 215, 0);
  epochs = 5; // Epochs per frame
  net = new SwNet(256, 4, 2, 0.1); // vecLen=256, widthBy=4, depth=2, rate=0.1

  // {x frequency, y frequency} for each curve: x = sin(fx * t), y = sin(fy * t).
  int[][] freqs = {
    {1, 2}, {2, 1}, {2, 3}, {3, 2},
    {3, 4}, {4, 3}, {2, 5}, {5, 2}
  };
  ex = new float[freqs.length][256];
  for (int curve = 0; curve < freqs.length; curve++) {
    float[] pts = ex[curve];
    int fx = freqs[curve][0];
    int fy = freqs[curve][1];
    // 127 samples fill slots 0..253; the final x,y pair of each row keeps its
    // default value of zero.
    for (int s = 0; s < 127; s++) {
      float t = (s * 2 * PI) / 127;
      pts[2 * s] = sin(fx * t);
      pts[2 * s + 1] = sin(fy * t);
    }
  }
}

// Per-frame work: trains the network for `epochs` passes over every example
// (each example is both input and target), then plots the training curves in
// one row and the network's recall of each curve in a second row.
void draw() {
  background(0);
  // NOTE(review): pixels[] is never read below and updatePixels() is commented
  // out; the plotting goes through set(). Confirm whether this call is needed.
  loadPixels();
  for (int pass = 0; pass < epochs; pass++) {
    for (int i = 0; i < ex.length; i++) {
      net.train(ex[i], ex[i]);
    }
  }
  // Row 1: the training data. Loop bounds follow the arrays rather than
  // repeating the literal sizes chosen in setup().
  for (int i = 0; i < ex.length; i++) {
    plotCurve(ex[i], 25 + i * 40, 44);
  }
  // Row 2: what the network reproduces for each training curve.
  for (int i = 0; i < ex.length; i++) {
    net.recall(work, ex[i]);
    plotCurve(work, 25 + i * 40, 104);
  }
  //updatePixels();
  text("Training Data", 5, 20);
  text("Autoassociative recall", 5, 80);
  text("Iterations: " + frameCount*epochs, 5, 150);
}

// Plots the interleaved x,y samples in pts as single pixels of colour c1,
// scaled by 18 and centred on (cx, cy). A trailing unpaired value is ignored.
void plotCurve(float[] pts, int cx, int cy) {
  for (int j = 0; j + 1 < pts.length; j += 2) {
    set(cx + int(18 * pts[j]), cy + int(18 * pts[j + 1]), c1);
  }
}
/**
 * Neural network whose mixing step is a fixed, normalized fast Walsh Hadamard
 * transform (WHT) and whose trainable part is a per-element two-slope
 * ("switching") activation: each element is multiplied by one of two
 * parameters, chosen by the element's sign bit.
 *
 * The internal width n = vecWidth * widthBy must be a power of two and at
 * least 16, because whtN() processes the vector in 16-element groups and then
 * doubles the butterfly span until it reaches n.
 */
public final class SwNet {
  final int vecWidth;     // length of the input and output vectors
  final int n;            // internal width: vecWidth * widthBy
  final int pn;           // total parameter count: 2 * n * depth
  final float scale;      // 1/sqrt(n), applied once per whtN() call
  final float[] params;   // two slopes per element per layer, laid out in pairs
  final float[] surface;  // pre-activation values of every layer, saved for train()
  final float[] work;     // n-element scratch vector transformed in place
  float rate;             // learning rate; multiplies the output error in train()

  /**
   * Creates a network and fills every parameter with randomGaussian().
   *
   * @param vecWidth length of input/output vectors (>= 1)
   * @param widthBy  how many sign-flipped copies of the input form the internal vector (>= 1)
   * @param depth    number of activation + WHT layers (>= 0)
   * @param rate     learning rate used by train()
   * @throws IllegalArgumentException if an argument is out of range or
   *         vecWidth * widthBy is not a power of two that is at least 16
   */
  public SwNet(int vecWidth, int widthBy, int depth, float rate) {
    if (vecWidth < 1 || widthBy < 1 || depth < 0) {
      throw new IllegalArgumentException(
        "vecWidth and widthBy must be >= 1 and depth >= 0; got vecWidth=" + vecWidth
        + ", widthBy=" + widthBy + ", depth=" + depth);
    }
    int width = vecWidth * widthBy;
    // whtN() reads work[i..i+15] per group and pairs work[i] with work[i+hs]
    // for hs = 16, 32, ...; any other width runs past the end of the array.
    if (width < 16 || (width & (width - 1)) != 0) {
      throw new IllegalArgumentException(
        "vecWidth * widthBy must be a power of two >= 16; got " + width);
    }
    this.vecWidth = vecWidth;
    this.rate = rate;
    n = width;
    pn = 2 * n * depth;
    params = new float[pn];
    surface = new float[pn >> 1];
    work = new float[n];
    scale = 1.0 / sqrt(n);
    for (int i = 0; i < pn; i++) {
      params[i] = randomGaussian();
    }
  }

  /**
   * Runs the network forward.
   *
   * @param result receives the first vecWidth outputs
   * @param input  source vector; its first vecWidth elements are read
   */
  public void recall(float[] result, float[] input) {
    recall(result, input, false);
  }

  /**
   * Forward pass. When sur is true, the pre-activation vector of each layer is
   * copied into surface so that train() can use it afterwards.
   * result may be the internal work array (train() calls it that way): the
   * final loop reads work[i] and writes result[i] at the same index.
   */
  void recall(float[] result, float[] input, boolean sur) {
    // rs is initialised once, so every copy of the input gets a different
    // stretch of the sign sequence.
    for (int i = 0, rs = 0x77777777; i < n; i += vecWidth) {
      for (int j = 0; j < vecWidth; j++) { // copy n times with sub-random sign flips
        rs ^= rs << 2;
        rs ^= rs >>> 3;
        // XOR the top bit of rs into the float's sign bit.
        work[i + j] = Float.intBitsToFloat(Float.floatToRawIntBits(input[j]) ^ (rs & 0x80000000));
      }
    }
    whtN();
    int pIdx = 0; // parameter index
    while (pIdx < pn) {
      if (sur)
        System.arraycopy(work, 0, surface, pIdx >> 1, n);
      for (int i = 0; i < n; i++, pIdx += 2) { // parametric switching activation function
        // Sign bit 0 selects params[pIdx], sign bit 1 selects params[pIdx + 1].
        work[i] = work[i] * params[pIdx + (Float.floatToRawIntBits(work[i]) >>> 31)];
      }
      whtN();
    }
    for (int i = 0, rs = 0x77777777; i < vecWidth; i++) { // sub-random sign flips to remove spectral bias
      rs ^= rs << 3;
      rs ^= rs >>> 2;
      result[i] = Float.intBitsToFloat(Float.floatToRawIntBits(work[i]) ^ (rs & 0x80000000));
    }
  }

  /**
   * Performs one training step on a single example and updates params in place.
   *
   * @param target desired output; its first vecWidth elements are read
   * @param input  input vector; its first vecWidth elements are read
   * @return the sum of squared output errors measured before the update
   */
  public float train(float[] target, float[] input) {
    float cost = 0f;
    recall(work, input, true);
    // Same rs sequence as the output stage of recall(), so the error receives
    // the same sign flips the output did.
    for (int i = 0, rs = 0x77777777; i < vecWidth; i++) {
      float e = target[i] - work[i];
      cost += e * e;
      e *= rate;
      rs ^= rs << 3;
      rs ^= rs >>> 2;
      work[i] = Float.intBitsToFloat(Float.floatToRawIntBits(e) ^ (rs & 0x80000000));
    }
    // Only the first vecWidth outputs carry an error; the rest are zeroed.
    java.util.Arrays.fill(work, vecWidth, n, 0.0);
    int pIdx = pn - 2;
    // Walk the layers from last to first; pIdx reaches -2 after the first
    // layer has been processed, which ends the loop.
    while (pIdx > 0) {
      whtN(); // error vector in work
      for (int i = n - 1; i >= 0; i--, pIdx -= 2) {
        float x = surface[pIdx >> 1]; // pre-activation value saved by recall()
        int signBit = Float.floatToRawIntBits(x) >>> 31;
        // Adjust only the slope that was active in the forward pass.
        float w = params[pIdx + signBit] + x * work[i];
        params[pIdx + signBit] = w;
        // The error is passed back through the already-updated slope.
        work[i] *= w;
      }
    }
    return cost;
  }

  // Fast Walsh Hadamard transform Normalized (scaled)
  // Transforms work in place. The first loop scales by `scale` and performs
  // the span-1, 2, 4 and 8 butterflies on each 16-element group in local
  // variables; the second loop performs the remaining spans 16, 32, ... < n.
  void whtN() {
    for (int i = 0; i < n; i += 16) {
      float a = work[i] * scale;
      float b = work[i + 1] * scale;
      float c = work[i + 2] * scale;
      float d = work[i + 3] * scale;
      float e = work[i + 4] * scale;
      float f = work[i + 5] * scale;
      float g = work[i + 6] * scale;
      float h = work[i + 7] * scale;

      float au = work[i + 8] * scale;
      float bu = work[i + 9] * scale;
      float cu = work[i + 10] * scale;
      float du = work[i + 11] * scale;
      float eu = work[i + 12] * scale;
      float fu = work[i + 13] * scale;
      float gu = work[i + 14] * scale;
      float hu = work[i + 15] * scale;

      // Lower half (a..h), span 1
      float t = a;
      a = a + b;
      b = t - b;
      t = c;
      c = c + d;
      d = t - d;
      t = e;
      e = e + f;
      f = t - f;
      t = g;
      g = g + h;
      h = t - h;

      // Lower half, span 2
      t = a;
      a = a + c;
      c = t - c;
      t = b;
      b = b + d;
      d = t - d;
      t = e;
      e = e + g;
      g = t - g;
      t = f;
      f = f + h;
      h = t - h;

      // Lower half, span 4
      t = a;
      a = a + e;
      e = t - e;
      t = b;
      b = b + f;
      f = t - f;
      t = c;
      c = c + g;
      g = t - g;
      t = d;
      d = d + h;
      h = t - h;

      // Upper half (au..hu), span 1
      t = au;
      au = au + bu;
      bu = t - bu;
      t = cu;
      cu = cu + du;
      du = t - du;
      t = eu;
      eu = eu + fu;
      fu = t - fu;
      t = gu;
      gu = gu + hu;
      hu = t - hu;

      // Upper half, span 2
      t = au;
      au = au + cu;
      cu = t - cu;
      t = bu;
      bu = bu + du;
      du = t - du;
      t = eu;
      eu = eu + gu;
      gu = t - gu;
      t = fu;
      fu = fu + hu;
      hu = t - hu;

      // Upper half, span 4
      t = au;
      au = au + eu;
      eu = t - eu;
      t = bu;
      bu = bu + fu;
      fu = t - fu;
      t = cu;
      cu = cu + gu;
      gu = t - gu;
      t = du;
      du = du + hu;
      hu = t - hu;

      // Span 8: combine the two halves and write the group back.
      work[i] = a + au;
      work[i + 1] = b + bu;
      work[i + 2] = c + cu;
      work[i + 3] = d + du;
      work[i + 4] = e + eu;
      work[i + 5] = f + fu;
      work[i + 6] = g + gu;
      work[i + 7] = h + hu;

      work[i + 8] = a - au;
      work[i + 9] = b - bu;
      work[i + 10] = c - cu;
      work[i + 11] = d - du;
      work[i + 12] = e - eu;
      work[i + 13] = f - fu;
      work[i + 14] = g - gu;
      work[i + 15] = h - hu;
    }
    int hs = 16; // current butterfly span
    while (hs < n) {
      int i = 0;
      while (i < n) {
        final int j = i + hs;
        while (i < j) {
          float a = work[i];
          float b = work[i + hs];
          work[i] = a + b;
          work[i + hs] = a - b;
          i += 1;
        }
        i += hs; // skip the partner block that was just written
      }
      hs += hs;
    }
  }
}

